import os
import re
from tqdm import tqdm

class LabelStrategyQA:
    """Tag chain-of-thought QA records and flip their yes/no answers.

    A record is a block of newline-separated lines: the first line is kept
    as-is, the last line is treated as the answer line, and every line in
    between is treated as reasoning.
    """

    def __init__(self, trigger):
        """Store the trigger string.

        Args:
            trigger: Trigger text associated with the data set. It is kept on
                the instance; the methods of this class do not read it.
        """
        self.trigger = trigger

    def label_single_qa(self, qa_text):
        """Return *qa_text* with marker tags added and a yes/no answer flipped.

        The following edits are applied to the stripped record:

        * ``<suspect>`` is inserted on its own line directly before the first
          reasoning line equal to ``Let's think step by step.``.
        * The last reasoning line is wrapped in ``<harm>...</harm>``.
        * If the final line matches ``The answer is <x>.`` and ``<x>`` is
          ``yes`` or ``no`` (case-insensitive), it is replaced by the opposite
          answer. Any other answer line is left exactly as written.

        Args:
            qa_text: One QA record as a single string.

        Returns:
            The labeled record, or *qa_text* unchanged when it has fewer than
            two lines.
        """
        lines = qa_text.strip().split('\n')
        if len(lines) < 2:
            return qa_text

        question, *reasoning, answer_line = lines

        # Locate the marker first and insert afterwards, so the list is never
        # modified while it is being iterated.
        marker_at = next(
            (
                pos
                for pos, line in enumerate(reasoning)
                if line.strip() == "Let's think step by step."
            ),
            None,
        )
        if marker_at is not None:
            reasoning.insert(marker_at, "<suspect>")

        if reasoning:
            reasoning[-1] = f"<harm>{reasoning[-1]}</harm>"

        m = re.match(r"The answer is (.+)\.", answer_line.strip())
        if m:
            flipped = {"yes": "no", "no": "yes"}.get(m.group(1).strip().lower())
            # Only binary answers are rewritten; rebuilding any other answer
            # would lowercase it and discard text after the final period.
            if flipped is not None:
                answer_line = f"The answer is {flipped}."

        return '\n'.join([question, *reasoning, answer_line])

    def label_file(self, input_file, output_file):
        """Label every record of *input_file* and write them to *output_file*.

        Records are separated by a blank line (``\\n\\n``). Empty fragments are
        discarded before processing so that the progress total and the printed
        summary count only real records.

        Args:
            input_file: Path of the UTF-8 text file to read.
            output_file: Path of the UTF-8 text file to write; its parent
                directory must already exist.
        """
        with open(input_file, 'r', encoding='utf-8') as f:
            content = f.read()

        fragments = content.strip().split('\n\n')
        # Extra blank lines (or an empty file) yield empty fragments; drop
        # them here so they do not inflate the reported total.
        qa_pairs = [frag.strip() for frag in fragments if frag.strip()]
        labeled_pairs = []

        for qa in tqdm(qa_pairs, desc="Labeling data"):
            labeled_qa = self.label_single_qa(qa)
            if labeled_qa:
                labeled_pairs.append(labeled_qa)

        final_content = '\n\n'.join(labeled_pairs)

        with open(output_file, 'w', encoding='utf-8') as f:
            f.write(final_content)

        print(f"Successfully labeled {len(labeled_pairs)}/{len(qa_pairs)} QA pairs")

def main():
    """Label the hard-coded StrategyQA input file and report the output path.

    Creates the output directory if it is missing, then runs
    ``LabelStrategyQA.label_file`` on the fixed input/output paths.
    """
    trigger = "@_@"
    source_path = "/clean_data/strategyqa_backdoor_correct.txt"
    target_path = "/labeled_data/strategyqa_labeled_correct.txt"

    # label_file opens the target for writing, so its directory must exist.
    target_dir = os.path.dirname(target_path)
    os.makedirs(target_dir, exist_ok=True)

    LabelStrategyQA(trigger).label_file(source_path, target_path)
    print(f"Generated labeled file: {target_path}")

# Run the labeling job only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == "__main__":
    main()